{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Pmap Cookbook",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LpPtl0n4rg6L",
        "colab_type": "text"
      },
      "source": [
        "# Colab JAX TPU Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4DYY4Yyhq8vG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import jax.tools.colab_tpu\n",
        "jax.tools.colab_tpu.setup_tpu()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_4ware9HrjIk",
        "colab_type": "text"
      },
      "source": [
        "# Pmap Cookbook"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sk-3cPGIBTq8",
        "colab_type": "text"
      },
      "source": [
        "This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n",
        "\n",
        "To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine or a Cloud TPU.\n",
        "\n",
        "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Srs8W9F6Jo15",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import jax.numpy as jnp"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBasY8p1JFId",
        "colab_type": "text"
      },
      "source": [
        "## Basics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "caPiPIWgM7-W",
        "colab_type": "text"
      },
      "source": [
        "### Pure maps, with no communication"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2e_06-OAJNyi",
        "colab_type": "text"
      },
      "source": [
        "A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6gGT77cIImcE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from jax import pmap"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-yY3lOFpJIUS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "result = pmap(lambda x: x ** 2)(jnp.arange(7))\n",
        "print(result)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PgKNzxKPNEYA",
        "colab_type": "text"
      },
      "source": [
        "In terms of what values are computed, `pmap` is similar to `vmap` in that it transforms a function to map over an array axis:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mmCMQ64QbAbz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from jax import vmap\n",
        "\n",
        "x = jnp.array([1., 2., 3.])\n",
        "y = jnp.array([2., 4., 6.])\n",
        "\n",
        "print(vmap(jnp.add)(x, y))\n",
        "print(pmap(jnp.add)(x, y))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iZgTmx5pFd6z",
        "colab_type": "text"
      },
      "source": [
        "But `pmap` and `vmap` differ in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4N1--GgGFe9d",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from jax import make_jaxpr\n",
        "\n",
        "def f(x, y):\n",
        "  a = jnp.dot(x, y)\n",
        "  b = jnp.tanh(a)\n",
        "  return b\n",
        "\n",
        "xs = jnp.ones((8, 2, 3))\n",
        "ys = jnp.ones((8, 3, 4))\n",
        "\n",
        "print(\"f jaxpr\")\n",
        "print(make_jaxpr(f)(xs[0], ys[0]))\n",
        "\n",
        "print(\"vmap(f) jaxpr\")\n",
        "print(make_jaxpr(vmap(f))(xs, ys))\n",
        "\n",
        "print(\"pmap(f) jaxpr\")\n",
        "print(make_jaxpr(pmap(f))(xs, ys))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BjDnQkzSa_vZ",
        "colab_type": "text"
      },
      "source": [
        "Notice that applying `vmap(f)` to these arguments leads to a `dot_general` to express the batch matrix multiplication in a single primitive, while applying `pmap(f)` instead leads to a primitive that calls replicas of the original `f` in parallel.\n",
        "\n",
        "An important constraint when using `pmap` is that the mapped axis size must be less than or equal to the number of XLA devices available (and for nested `pmap` functions, the product of the mapped axis sizes must be less than or equal to the number of XLA devices).\n",
        "\n",
        "You can use the output of a `pmap` function just like any other value:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H4DXQWobOf7V",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
        "z = y / 2\n",
        "print(z)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fM1Une9Rfqld",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "plt.plot(y)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "644UB23YfbW4",
        "colab_type": "text"
      },
      "source": [
        "But while the output here acts just like a NumPy ndarray, if you look closely it has a different type:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "59hnyVOtfavX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "y"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4brdSdeyf2MP",
        "colab_type": "text"
      },
      "source": [
        "A `ShardedDeviceArray` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
        "\n",
        "When you call a non-`pmap` function on a `ShardedDeviceArray`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BSSllkblg9Rn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "y / 2"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "efyMSNGahq6f",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "np.sin(y)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ba4jwfkbOwXW",
        "colab_type": "text"
      },
      "source": [
        "Thinking about device memory is important for maximizing performance by avoiding data transfers, but you can always fall back to treating array-like values as (read-only) NumPy ndarrays and your code will still work.\n",
        "\n",
        "Here's another example of a pure map which makes better use of our multiple-accelerator resources. We can generate several large random matrices in parallel, then perform parallel batch matrix multiplication without any cross-device movement of the large matrix data:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rWl68coLJSi7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from jax import random\n",
        "\n",
        "# create 8 random keys\n",
        "keys = random.split(random.PRNGKey(0), 8)\n",
        "# create a 5000 x 6000 matrix on each device by mapping over keys\n",
        "mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)\n",
        "# the stack of matrices is represented logically as a single array\n",
        "mats.shape"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nH2gGNgfNOJD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# run a local matmul on each device in parallel (no data transfer)\n",
        "result = pmap(lambda x: jnp.dot(x, x.T))(mats)\n",
        "result.shape"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MKTZ59iPNPi5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# compute the mean on each device in parallel and print the results\n",
        "print(pmap(jnp.mean)(result))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "26iH7sHFiz2l",
        "colab_type": "text"
      },
      "source": [
        "In this example, the large matrices never had to be moved between devices or back to the host; only one scalar per device was pulled back to the host."
      ]
    },
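    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Because the mapped axis size can't exceed the number of available devices, it can help to size inputs from the device count rather than hard-coding it. The cell below is a minimal sketch, not one of the original examples: it assumes `jax.local_device_count()` reports the devices attached to this host, and the names `n_devices` and `per_device_inputs` are purely illustrative."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import jax\n",
        "\n",
        "# number of devices attached to this host (assumed to be what pmap can map over)\n",
        "n_devices = jax.local_device_count()\n",
        "\n",
        "# one input element per device, so the mapped axis always fits\n",
        "per_device_inputs = jnp.arange(n_devices)\n",
        "print(pmap(lambda x: x + 1)(per_device_inputs))"
      ],
      "execution_count": 0,
      "outputs": []
    },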
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MdRscR5MONuN",
        "colab_type": "text"
      },
      "source": [
        "### Collective communication operations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bFtajUwp5WYx",
        "colab_type": "text"
      },
      "source": [
        "In addition to expressing pure maps, where no communication happens between the replicated functions, with `pmap` you can also use special collective communication operations.\n",
        "\n",
        "One canonical example of a collective, implemented on both GPU and TPU, is an all-reduce sum like `lax.psum`:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d5s8rJVUORQ3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from jax import lax\n",
        "\n",
        "normalize = lambda x: x / lax.psum(x, axis_name='i')\n",
        "result = pmap(normalize, axis_name='i')(jnp.arange(4.))\n",
        "print(result)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6jd9DVBQPD-Z",
        "colab_type": "text"
      },
      "source": [
        "To use a collective operation like `lax.psum`, you need to supply an `axis_name` argument to `pmap`. The `axis_name` argument associates a name with the mapped axis so that collective operations can refer to it.\n",
        "\n",
        "Another way to write this same code is to use `pmap` as a decorator:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c48qVvlkPF5p",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from functools import partial\n",
        "\n",
        "@partial(pmap, axis_name='i')\n",
        "def normalize(x):\n",
        "  return x / lax.psum(x, 'i')\n",
        "\n",
        "print(normalize(jnp.arange(4.)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
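    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Other collectives follow the same pattern. As a small sketch, assuming at least four devices are available as in the cells above, `lax.pmean` averages a value across the named axis; here each replica subtracts that mean from its own element. The name `center` is illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# every replica receives the same cross-device mean, then subtracts it locally\n",
        "center = pmap(lambda x: x - lax.pmean(x, axis_name='i'), axis_name='i')\n",
        "print(center(jnp.arange(4.)))"
      ],
      "execution_count": 0,
      "outputs": []
    },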
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Pr6n8KkOpmz",
        "colab_type": "text"
      },
      "source": [
        "Axis names are also important for nested use of `pmap`, where collectives can be applied to distinct mapped axes:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IwoeEd16OrD3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@partial(pmap, axis_name='rows')\n",
        "@partial(pmap, axis_name='cols')\n",
        "def f(x):\n",
        "  row_normed = x / lax.psum(x, 'rows')\n",
        "  col_normed = x / lax.psum(x, 'cols')\n",
        "  doubly_normed = x / lax.psum(x, ('rows', 'cols'))\n",
        "  return row_normed, col_normed, doubly_normed\n",
        "\n",
        "x = jnp.arange(8.).reshape((4, 2))\n",
        "a, b, c = f(x)\n",
        "\n",
        "print(a)\n",
        "print(a.sum(0))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bnc-vlKA6hvI",
        "colab_type": "text"
      },
      "source": [
        "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n",
        "\n",
        "Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n",
        "\n",
        "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uazGbMwmf5zO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import jax\n",
        "\n",
        "# total number of XLA devices; the board is split into one slice per device\n",
        "device_count = jax.device_count()\n",
        "\n",
        "def send_right(x, axis_name):\n",
        "  # device i sends its value to device i + 1 (wrapping around)\n",
        "  right_perm = [(i, (i + 1) % device_count) for i in range(device_count)]\n",
        "  return lax.ppermute(x, perm=right_perm, axis_name=axis_name)\n",
        "\n",
        "def send_left(x, axis_name):\n",
        "  # device i + 1 sends its value to device i (wrapping around)\n",
        "  left_perm = [((i + 1) % device_count, i) for i in range(device_count)]\n",
        "  return lax.ppermute(x, perm=left_perm, axis_name=axis_name)\n",
        "\n",
        "def update_board(board):\n",
        "  left = board[:-2]\n",
        "  right = board[2:]\n",
        "  center = board[1:-1]\n",
        "  return lax.bitwise_xor(left, lax.bitwise_or(center, right))\n",
        "\n",
        "@partial(pmap, axis_name='i')\n",
        "def step(board_slice):\n",
        "  left, right = board_slice[:1], board_slice[-1:]\n",
        "  right, left = send_left(left, 'i'), send_right(right, 'i')\n",
        "  enlarged_board_slice = jnp.concatenate([left, board_slice, right])\n",
        "  return update_board(enlarged_board_slice)\n",
        "\n",
        "def print_board(board):\n",
        "  print(''.join('*' if x else ' ' for x in board.ravel()))\n",
        "\n",
        "\n",
        "board = np.zeros(40, dtype=bool)\n",
        "board[board.shape[0] // 2] = True\n",
        "reshaped_board = board.reshape((device_count, -1))\n",
        "\n",
        "print_board(reshaped_board)\n",
        "for _ in range(20):\n",
        "  reshaped_board = step(reshaped_board)\n",
        "  print_board(reshaped_board)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KrkEuY3yO7_M",
        "colab_type": "text"
      },
      "source": [
        "## Composing with differentiation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dGHE7dfypqqU",
        "colab_type": "text"
      },
      "source": [
        "As with all things in JAX, you should expect `pmap` to compose with other transformations, including differentiation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VkS7_RcTO_48",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from jax import grad\n",
        "\n",
        "@pmap\n",
        "def f(x):\n",
        "  y = jnp.sin(x)\n",
        "  @pmap\n",
        "  def g(z):\n",
        "    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
        "  return grad(lambda w: jnp.sum(g(w)))(x)\n",
        "\n",
        "f(x)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4gAJ3QF6PBvi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "grad(lambda x: jnp.sum(f(x)))(x)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8mAz9bEfPl2F",
        "colab_type": "text"
      },
      "source": [
        "When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the backward pass of the computation is parallelized just like the forward pass."
      ]
    }
  ]
}
